use std::ops::Range;

use ahash::AHashMap;
use arena::{Arena, IdxRange};
use bitset::{BitSet, HybridBitSet};
use mir::builder::{InsertBuilder, InstBuilder, InstInserterBase};
use mir::{
    Block, Function, Inst, InstructionData, Opcode, SourceLoc, Unknown, Value, F_LOG10_E, F_ONE,
    F_TWO, F_ZERO,
};
use stdx::iter::zip;
use stdx::packed_option::{PackedOption, ReservedValue};

use crate::intern::{Derivative, DerivativeIntern};
use crate::live_derivatives::LiveDerivatives;

#[cfg(test)]
mod tests;

pub fn build_derivatives(
    func: &mut Function,
    intern: &mut DerivativeIntern,
    live_derivatives: &LiveDerivatives,
    post_order: &[Block],
) -> AHashMap<(Value, Unknown), Value> {
    let derivative_values: AHashMap<(Value, Unknown), Value> =
        intern.unknowns.iter_enumerated().map(|(unknown, &val)| ((val, unknown), F_ONE)).collect();

    let mut known_values = BitSet::new_empty(func.dfg.num_values());
    for val in func.dfg.values() {
        if func.dfg.value_def(val).inst().is_none() {
            known_values.insert(val);
        }
    }

    let mut builder = DerivativeBuilder {
        func,
        live_derivatives,
        intern,
        derivative_values,
        known_values,
        dst: (0u32.into(), SourceLoc::new(0)),
        cyclical_phis: Vec::with_capacity(64),
        new_block: None,
    };

    builder.run(post_order);
    builder.derivative_values
}

/// Per-instruction cache of up to three optional helper values.
///
/// Filled by `DerivativeBuilder::inst_cache` and read by `DerivativeBuilder::inst_derivative`;
/// the meaning of each slot depends on the opcode of the instruction the cache belongs to.
type CacheData = [PackedOption<Value>; 3];

/// State used while inserting derivative instructions into a function.
///
/// The builder also acts as the instruction inserter (see the `InstInserterBase` impl):
/// every instruction created with `ins()` is placed relative to `dst`/`new_block`.
pub(crate) struct DerivativeBuilder<'a, 'u> {
    /// The function that is read and extended with derivative instructions.
    func: &'a mut Function,

    /// Tells which derivatives have to be generated for each instruction and which chain rule
    /// conversions have to be inserted.
    live_derivatives: &'a LiveDerivatives,
    /// Interner used to resolve derivatives (previous order, unknowns, ddx calls).
    intern: &'a mut DerivativeIntern<'u>,

    /// Maps `(value, unknown)` to the value holding the derivative of `value` by `unknown`.
    /// A missing entry is treated as a derivative of zero by the lookup functions.
    derivative_values: AHashMap<(Value, Unknown), Value>,
    /// Values without a defining instruction plus the results of all instructions inserted by
    /// this builder. Used to detect phis whose edges are not all available yet.
    known_values: BitSet<Value>,
    /// The instruction after which the next instruction is appended, together with the source
    /// location assigned to inserted instructions.
    dst: (Inst, SourceLoc),
    /// When set, the next inserted instruction is placed at the start of this block instead of
    /// after `dst.0`. The option is consumed by that insertion.
    new_block: Option<Block>,

    /// Copies of phis whose arguments are replaced by derivatives at the end of `run`.
    cyclical_phis: Vec<(Inst, Derivative)>,
}

/// Lets the instruction builder insert directly through a `DerivativeBuilder`.
///
/// Each inserted instruction becomes the new insertion point, so consecutive inserts end up in
/// program order.
impl<'f> InstInserterBase<'f> for &'f mut DerivativeBuilder<'_, '_> {
    fn data_flow_graph(&self) -> &mir::DataFlowGraph {
        &self.func.dfg
    }

    fn data_flow_graph_mut(&mut self) -> &mut mir::DataFlowGraph {
        &mut self.func.dfg
    }

    /// Places `inst` in the layout, assigns it the current source location and marks its
    /// results as known.
    fn insert_built_inst(self, inst: Inst) -> &'f mut mir::DataFlowGraph {
        // A pending block takes precedence (and is consumed): the instruction goes to the very
        // start of that block. Otherwise it follows the current destination instruction.
        match self.new_block.take() {
            Some(block) => match self.func.layout.first_inst(block) {
                Some(head) => {
                    self.func.layout.prepend_inst(inst, head);
                }
                None => {
                    self.func.layout.append_inst_to_bb(inst, block);
                }
            },
            None => {
                self.func.layout.append_inst(inst, self.dst.0);
            }
        }

        // grow the source location table if the instruction lies beyond its end
        let srcloc = self.dst.1;
        let pos: usize = inst.into();
        if pos < self.func.srclocs.len() {
            self.func.srclocs[inst] = srcloc;
        } else {
            self.func.srclocs.resize(pos, SourceLoc::default());
            self.func.srclocs.push(srcloc);
        }

        self.known_values.ensure(self.func.dfg.num_values());
        for &result in self.func.dfg.inst_results(inst) {
            self.known_values.insert(result);
        }

        // the next instruction is appended right after this one
        self.dst.0 = inst;
        &mut self.func.dfg
    }
}

impl<'a, 'u> DerivativeBuilder<'a, 'u> {
    /// generates all derivatives and stores them in `self.derivative_values`
    ///
    /// Phi nodes are only cloned so that they have a return value that can be stored in
    /// `self.derivative_values`.
    ///
    /// Phis may form circular references and can therefore not be generated by a simple reverse
    /// postorder walk. Instead their edges are updated in the `build_phis` function
    pub fn run(&mut self, post_order: &[Block]) {
        let mut cache = BuilderCache::default();

        for bb in post_order.iter().rev() {
            let mut cursor = self.func.layout.block_inst_cursor(*bb);
            while let Some(inst) = cursor.next(&self.func.layout) {
                let mut srcloc = self.func.srclocs.get(inst).copied().unwrap_or_default();
                srcloc.0 *= -1;
                self.dst = (inst, srcloc);
                self.build_inst_derivatives(&mut cache);
            }
        }

        // populate phis with derivatives (now all values are either known/phi dummys)
        for (phi, derivative) in &self.cyclical_phis {
            self.func.dfg.zap_inst(*phi);
            for arg in self.func.dfg.instr_args_mut(*phi) {
                *arg =
                    Self::derivative_of_(self.intern, &self.derivative_values, *arg, *derivative);
            }
            self.func.dfg.update_inst_uses(*phi)
        }
    }

    /// Builds the value that replaces a `ddx` call on `arg`.
    ///
    /// The derivatives of `arg` by the unknowns in `pos_unknowns` are added and those by the
    /// unknowns in `neg_unknowns` are subtracted. Derivatives that are `F_ZERO` are skipped, so
    /// `F_ZERO` is returned if no derivative contributes.
    ///
    /// If `higher_order_derivative` is set, each unknown is first combined with it (see
    /// `raise_order`) and the resulting higher order derivative of `arg` is used instead.
    fn ddx(
        &mut self,
        arg: Value,
        pos_unknowns: &HybridBitSet<Unknown>,
        neg_unknowns: &HybridBitSet<Unknown>,
        higher_order_derivative: Option<Derivative>,
    ) -> Value {
        let mut total = F_ZERO;

        for next_unknown in pos_unknowns.iter() {
            let term = self.ddx_term(arg, next_unknown, higher_order_derivative);
            if term != F_ZERO {
                total = if total == F_ZERO { term } else { self.ins().fadd(total, term) };
            }
        }

        for next_unknown in neg_unknowns.iter() {
            let term = self.ddx_term(arg, next_unknown, higher_order_derivative);
            if term != F_ZERO {
                total = if total == F_ZERO {
                    self.ins().fneg(term)
                } else {
                    self.ins().fsub(total, term)
                };
            }
        }

        total
    }

    /// Looks up the contribution of a single unknown to a `ddx` call (see `ddx`).
    fn ddx_term(
        &mut self,
        arg: Value,
        next_unknown: Unknown,
        higher_order_derivative: Option<Derivative>,
    ) -> Value {
        match higher_order_derivative {
            Some(lower_order) => {
                let raised = self.intern.raise_order(lower_order, next_unknown);
                self.derivative_of(arg, raised)
            }
            None => self.derivative_of_1(arg, next_unknown),
        }
    }

    /// Generates the derivatives required for the instruction at `self.dst.0`.
    ///
    /// Three kinds of instructions receive special treatment:
    /// * calls found in `intern.ddx_calls` are replaced by the derivative they request,
    /// * phis get derivative phis (delayed if one of their edges is not known yet),
    /// * `pow` derivatives are only calculated if the base is not zero.
    ///
    /// Every other instruction is handled by `build_normal_inst_derivatives`, followed by
    /// `insert_conversions`, if there are derivatives to generate for it.
    fn build_inst_derivatives(&mut self, bcache: &mut BuilderCache) {
        let derivatives = self.live_derivatives.of_inst(self.dst.0);

        // remember the instruction: inserting instructions moves `self.dst.0`
        let inst = self.dst.0;
        match self.func.dfg.insts[inst].clone() {
            // ddx calls just get replaced with the appropriate derivatives
            InstructionData::Call { func_ref, args } => {
                if let Some((pos_unknowns, neg_unknowns)) = self.intern.ddx_calls.get(&func_ref) {
                    let inst = self.dst.0;
                    let arg = args.as_slice(&self.func.dfg.insts.value_lists)[0];
                    let res = self.ddx(arg, pos_unknowns, neg_unknowns, None);

                    // replace call with calculated derivative
                    let old = self.func.dfg.first_result(inst);
                    self.func.dfg.replace_uses(old, res);
                    self.func.dfg.zap_inst(inst);
                    self.func.layout.remove_inst(inst);

                    // derivatives of the ddx call itself are derivatives of one order higher
                    // of the argument
                    let higher_order_derivatives =
                        self.live_derivatives.compute_inst(inst, self.func, self.intern);
                    for derivative in higher_order_derivatives.iter() {
                        let ddx_val = self.ddx(arg, pos_unknowns, neg_unknowns, Some(derivative));

                        let prev_order = match self.intern.previous_order(derivative) {
                            Some(val) => self.derivative_of(res, val),
                            None => res, // 0th order is the original result
                        };
                        let unknown = self.intern.get_unknown(derivative);
                        self.insert_derivative(prev_order, unknown, ddx_val);
                    }

                    debug_assert!(self.live_derivatives.conversions.get(&inst).is_none());
                }
            }

            // place dummy values for phis to break cycles
            // arguments are populated with derivatives later
            InstructionData::PhiNode(phi) => {
                if let Some(derivatives) = derivatives {
                    let res = self.func.dfg.first_result(self.dst.0);
                    // a phi is treated as cyclical if one of its incoming values is neither
                    // defined without an instruction nor produced by this builder
                    let is_cyclical = self
                        .func
                        .dfg
                        .phi_edges(&phi)
                        .any(|(_, val)| !self.known_values.contains(val));

                    if is_cyclical {
                        // we do not calculate phis just yet because it requires the value of a backwards edges.
                        // Instead we just copy the phi and delay until all other derivatives are
                        // done
                        for derivative in derivatives.iter() {
                            let prev_order = match self.intern.previous_order(derivative) {
                                Some(val) => self.derivative_of(res, val),
                                None => res, // 0th order is the original result
                            };
                            let unknown = self.intern.get_unknown(derivative);

                            let edges: Vec<_> = self.func.dfg.phi_edges(&phi).collect();
                            let val = self.ins().phi(&edges);
                            self.derivative_values.insert((prev_order, unknown), val);

                            // inserting the copy moved `self.dst.0` to it, so the copy (and
                            // not the original phi) is what gets patched at the end of `run`
                            self.cyclical_phis.push((self.dst.0, derivative))
                        }
                    } else {
                        for derivative in derivatives.iter() {
                            let prev_order = self.prev_order_derivative_of(res, derivative);
                            let unknown = self.intern.get_unknown(derivative);

                            let edges: Vec<_> = self
                                .func
                                .dfg
                                .phi_edges(&phi)
                                .map(|(bb, val)| (bb, self.derivative_of(val, derivative)))
                                .collect();

                            // if all edges carry the same derivative no phi is required
                            if edges.iter().all(|(_, val)| *val == edges[0].1) {
                                self.insert_derivative(prev_order, unknown, edges[0].1);
                            } else {
                                let val = self.ins().phi(&edges);
                                self.derivative_values.insert((prev_order, unknown), val);
                            }
                        }
                    }
                }

                self.insert_conversions(inst);
            }
            InstructionData::Binary { opcode: Opcode::Pow, args: [base, _] } => {
                if let Some(derivatives) = derivatives {
                    let inst = self.dst.0;
                    let is_base_zero = self.ins().feq(base, F_ZERO);

                    let old_block =
                        self.func.layout.inst_block(inst).expect("instruction is attached");
                    // `new_block` receives everything that followed the pow instruction
                    let new_block = self.func.layout.make_block();

                    if let Some(next_inst) = self.func.layout.next_inst(self.dst.0) {
                        self.func.split_block(new_block, next_inst);
                    } else {
                        self.func.layout.append_block(new_block);
                    };
                    // the derivatives are only calculated in this block, which is skipped if
                    // the base is zero
                    let calculate_derivative_block = self.func.layout.make_block();
                    self.func.layout.insert_block_after(calculate_derivative_block, old_block);
                    self.ins().br(is_base_zero, new_block, calculate_derivative_block);

                    // insert into the newly created block
                    self.new_block = Some(calculate_derivative_block);
                    self.dst.0 = inst;
                    self.build_normal_inst_derivatives(bcache, derivatives);
                    self.ins().jump(new_block);

                    // the phis below go to the start of the block following the pow
                    self.new_block = Some(new_block);
                    let res = self.func.dfg.first_result(inst);

                    // replace the calculated derivatives with phis that return
                    // 0 in case that base is zero to ensure numerical stability
                    // requires collecting derivatives into a temporary vector
                    // because we are going to overwrite derivatives that are required
                    // for looking up higher order derivatives
                    let new_derivatives: Vec<_> = derivatives
                        .iter()
                        .filter_map(|derivative| {
                            let val = self.derivative_of(res, derivative);
                            if val == F_ZERO {
                                return None;
                            }

                            let checked_val = self
                                .ins()
                                .phi(&[(old_block, F_ZERO), (calculate_derivative_block, val)]);
                            Some((checked_val, derivative))
                        })
                        .collect();

                    for (val, derivative) in new_derivatives {
                        let prev_order = self.prev_order_derivative_of(res, derivative);
                        let unknown = self.intern.get_unknown(derivative);
                        self.derivative_values.insert((prev_order, unknown), val);
                    }

                    self.insert_conversions(inst);
                    // make sure no pending block leaks into the next instruction
                    self.new_block.take();
                }
            }
            _ => {
                if let Some(derivatives) = derivatives {
                    self.build_normal_inst_derivatives(bcache, derivatives);
                    self.insert_conversions(inst);
                }
            }
        }
    }

    fn build_normal_inst_derivatives(
        &mut self,
        bcache: &mut BuilderCache,
        derivatives: &'a HybridBitSet<Derivative>,
    ) {
        // add the original instruction
        bcache.resolved_derivatives.insert(None, ResolvedDerivative::root_instr(self.dst.0));

        for derivative in derivatives.iter() {
            let prev_order = self.intern.previous_order(derivative);
            let base = self.intern.get_unknown(derivative);

            let origin = bcache.resolved_derivatives[&prev_order].clone();
            let cache = self.ensure_cache(prev_order, &origin, bcache);

            let inst_start = self.func.dfg.num_insts().into();

            for (inst, cache_data_i) in zip(origin.instructions(), cache.data) {
                if !self.func.dfg.has_results(inst) {
                    continue;
                }
                let res = self.func.dfg.first_result(inst);
                if !self.derivative_values.contains_key(&(res, base)) {
                    self.inst_derivative(inst, base, bcache.cache_data[cache_data_i]);
                }
            }

            let inst_end = self.func.dfg.num_insts().into();

            bcache.resolved_derivatives.insert(
                Some(derivative),
                ResolvedDerivative { instrs: inst_start..inst_end, cache_instrs: cache.instrs },
            );
        }

        bcache.clear();
    }

    /// Inserts the chain rule products registered for `inst` in `live_derivatives.conversions`.
    ///
    /// The conversions are processed back to front. A conversion is skipped if its outer or its
    /// inner derivative is `F_ZERO`; otherwise the product of both is stored as the destination
    /// derivative of the conversion's value.
    fn insert_conversions(&mut self, inst: Inst) {
        let conversion = match self.live_derivatives.conversions.get(&inst) {
            Some(conversion) => conversion,
            None => return,
        };

        for chain_rule in conversion.iter().rev() {
            let outer = self.derivative_of(chain_rule.val, chain_rule.outer_derivative);
            if outer != F_ZERO {
                let inner = self
                    .derivative_of(chain_rule.inner_derivative.0, chain_rule.inner_derivative.1);
                if inner != F_ZERO {
                    let target =
                        self.prev_order_derivative_of(chain_rule.val, chain_rule.dst_derivative);
                    debug_assert_ne!(target, F_ZERO);
                    let unknown = self.intern.get_unknown(chain_rule.dst_derivative);
                    let product = self.ins().fmul(inner, outer);
                    self.derivative_values.insert((target, unknown), product);
                }
            }
        }
    }

    /// Returns the cache belonging to `prev_order`, building it on first use.
    ///
    /// Building the cache calls `inst_cache` for every instruction of `prev_order_instr` and
    /// appends the results to `bcache.cache_data`. The returned info holds the range of
    /// instructions inserted while doing so and the range of the appended cache entries.
    fn ensure_cache(
        &mut self,
        prev_order: Option<Derivative>,
        prev_order_instr: &ResolvedDerivative,
        bcache: &mut BuilderCache,
    ) -> CacheInfo {
        // only runs if no cache exists for `prev_order` yet
        let build_cache = || {
            let first_inst = self.func.dfg.num_insts().into();
            let first_entry = bcache.cache_data.len().into();

            let entries = prev_order_instr.instructions().map(|inst| self.inst_cache(inst));
            bcache.cache_data.extend(entries);

            let end_inst = self.func.dfg.num_insts().into();
            let end_entry = bcache.cache_data.len().into();

            CacheInfo {
                instrs: first_inst..end_inst,
                data: IdxRange::new(first_entry..end_entry),
            }
        };

        bcache.derivative_cache.entry(prev_order).or_insert_with(build_cache).clone()
    }

    /// Returns an instruction builder that inserts through this `DerivativeBuilder`, i.e. at
    /// the position determined by `self.new_block`/`self.dst` (see `insert_built_inst`).
    fn ins(&mut self) -> InsertBuilder<&mut DerivativeBuilder<'a, 'u>> {
        InsertBuilder::new(self)
    }

    /// Returns the first order derivative of `val` by `unknown`, or `F_ZERO` if none is stored.
    fn derivative_of_1(&self, val: Value, unknown: Unknown) -> Value {
        match self.derivative_values.get(&(val, unknown)) {
            Some(&derivative) => derivative,
            None => F_ZERO,
        }
    }

    /// Returns the value whose first order derivative yields `derivative` of `val`.
    ///
    /// That is the derivative of `val` by the previous order of `derivative`, or `val` itself
    /// if `derivative` has no previous order.
    fn prev_order_derivative_of(&self, val: Value, derivative: Derivative) -> Value {
        let prev_order = match self.intern.previous_order(derivative) {
            Some(prev_order) => prev_order,
            None => return val,
        };
        Self::derivative_of_(self.intern, &self.derivative_values, val, prev_order)
    }

    /// Returns the (possibly higher order) `derivative` of `val`, or `F_ZERO` if it is not
    /// stored. Convenience wrapper around `derivative_of_` using the builder's own state.
    fn derivative_of(&self, val: Value, derivative: Derivative) -> Value {
        Self::derivative_of_(self.intern, &self.derivative_values, val, derivative)
    }

    /// Resolves `derivative` of `val` by following the first order derivatives in
    /// `derivative_values`, one unknown of `intern.unknowns_rev(derivative)` at a time.
    ///
    /// Returns `F_ZERO` as soon as one of the steps has no stored derivative. Takes the lookup
    /// tables as explicit arguments so it can be used while other parts of the builder are
    /// borrowed.
    fn derivative_of_(
        intern: &DerivativeIntern,
        derivative_values: &AHashMap<(Value, Unknown), Value>,
        val: Value,
        derivative: Derivative,
    ) -> Value {
        let mut current = val;
        for unknown in intern.unknowns_rev(derivative) {
            match derivative_values.get(&(current, unknown)) {
                Some(&next) => current = next,
                None => return F_ZERO,
            }
        }
        current
    }

    /// Stores `val` as the derivative of `original` by `unknown`.
    ///
    /// Derivatives equal to `F_ZERO` are not stored (a missing entry already means zero).
    ///
    /// # Panics
    ///
    /// With debug assertions enabled this panics if a derivative was already stored for
    /// `(original, unknown)`.
    fn insert_derivative(&mut self, original: Value, unknown: Unknown, val: Value) {
        if val == F_ZERO {
            return;
        }

        let old = match self.derivative_values.insert((original, unknown), val) {
            Some(old) if cfg!(debug_assertions) => old,
            _ => return,
        };

        // generating the same derivative twice is a bug, report all involved instructions
        let dfg = &self.func.dfg;
        let original_inst = dfg.display_inst(dfg.value_def(original).unwrap_inst());
        let old_inst = dfg.display_inst(dfg.value_def(old).unwrap_inst());
        let new_inst = dfg.display_inst(dfg.value_def(val).unwrap_inst());
        panic!("derivative of {original} by {unknown:?} generated twice: {old} {val}\norg: {original_inst}\n{new_inst}\n{old_inst}")
    }

    /// Builds the helper values that `inst_derivative` needs to differentiate `inst`.
    ///
    /// The helper values only depend on the instruction itself (not on the unknown) and can
    /// therefore be shared between the derivatives by different unknowns. Their meaning depends
    /// on the opcode:
    ///
    /// * multiplicative rules (`exp`, `sin`, ...): slot 0 is the factor the argument derivative
    ///   is multiplied with,
    /// * divisive rules (`ln`, `sqrt`, `asin`, ...): slot 0 is the divisor of the argument
    ///   derivative,
    /// * `fdiv`: slot 0 is the squared divisor,
    /// * `pow`/`atan2`: the derivative is `(dlhs * slot0 + drhs * slot1) * slot2`.
    ///
    /// All slots are empty for opcodes that do not need a cache.
    fn inst_cache(&mut self, inst: Inst) -> CacheData {
        let mut cache = [None.into(), None.into(), None.into()];

        let op = self.func.dfg.insts[inst].opcode();
        let args = self.func.dfg.instr_args(inst);
        let arg0 = args.get(0).copied().unwrap_or_else(Value::reserved_value);
        let arg1 = args.get(1).copied().unwrap_or_else(Value::reserved_value);
        let res = self.func.dfg.first_result(inst);

        let val = match op {
            // Opcode::Idiv => {
            //     let val = self.ins().imul(arg1, arg1);
            //     self.ins().ifcast(val)
            // }
            // (f/g)' requires g^2
            Opcode::Fdiv => self.ins().fmul(arg1, arg1),

            // Technically not required but makes code look nicer..
            // exp(x) -> exp(x)
            Opcode::Exp => res,

            // sqrt(x) -> 1/(2*sqrt(x)), hypot stores the same `2 * res` value
            Opcode::Hypot | Opcode::Sqrt => self.ins().fmul(F_TWO, res),
            // ln(x) -> 1/x
            Opcode::Ln => arg0,
            // log(x) -> log(e)/x
            Opcode::Log => self.ins().fdiv(F_LOG10_E, arg0),
            // sin(x) -> cos(x)
            Opcode::Sin => self.ins().cos(arg0),
            // cos(x) -> -sin(x)
            Opcode::Cos => {
                let sin = self.ins().sin(arg0);
                self.ins().fneg(sin)
            }
            // tan(x) -> 1 + tan^2(x)
            Opcode::Tan => {
                let tan_2 = self.ins().fmul(res, res);
                self.ins().fadd(F_ONE, tan_2)
            }

            // asin(x) -> 1/sqrt(1-x^2)
            Opcode::Asin => {
                // sqrt(1 - x^2)
                let arg_squared = self.ins().fmul(arg0, arg0);
                let sqrt_arg = self.ins().fsub(F_ONE, arg_squared);
                self.ins().sqrt(sqrt_arg)
            }
            // acos(x) -> -1/sqrt(1-x^2)
            Opcode::Acos => {
                // -sqrt(1 - x^2)
                let arg_squared = self.ins().fmul(arg0, arg0);
                let sqrt_arg = self.ins().fsub(F_ONE, arg_squared);
                let sqrt = self.ins().sqrt(sqrt_arg);
                self.ins().fneg(sqrt)
            }

            // arctan(x) -> 1/(1 + x^2)
            Opcode::Atan => {
                // 1 + x^2
                let arg_squared = self.ins().fmul(arg0, arg0);
                self.ins().fadd(F_ONE, arg_squared)
            }
            // arctan2(x,y) => (x'*y - y'*x)/(x^2+y^2)
            //
            // `inst_derivative` evaluates `(x' * cache[0] + y' * cache[1]) * cache[2]`, so the
            // sign of the second term and the division have to be part of the cached values:
            // cache[0] = y, cache[1] = -x, cache[2] = 1/(x^2+y^2)
            Opcode::Atan2 => {
                let lhs_squared = self.ins().fmul(arg0, arg0);
                let rhs_squared = self.ins().fmul(arg1, arg1);
                let bot = self.ins().fadd(lhs_squared, rhs_squared);
                let scale = self.ins().fdiv(F_ONE, bot);
                let neg_lhs = self.ins().fneg(arg0);

                cache[2] = scale.into();
                cache[1] = neg_lhs.into();
                arg1
            }

            // sinh(x) -> cosh(x)
            Opcode::Sinh => self.ins().cosh(arg0),
            // cosh(x) -> sinh(x)
            Opcode::Cosh => self.ins().sinh(arg0),

            // tanh(x) -> 1 - tanh^2(x)
            Opcode::Tanh => {
                let tan_2 = self.ins().fmul(res, res);
                self.ins().fsub(F_ONE, tan_2)
            }
            // asinh(x) -> 1/sqrt(x^2 + 1)
            Opcode::Asinh => {
                // sqrt(1 + x^2)
                let arg_squared = self.ins().fmul(arg0, arg0);
                let sqrt_arg = self.ins().fadd(F_ONE, arg_squared);
                self.ins().sqrt(sqrt_arg)
            }
            // acosh(x) -> 1/sqrt(x^2 - 1)
            Opcode::Acosh => {
                // sqrt(x^2 - 1)
                let arg_squared = self.ins().fmul(arg0, arg0);
                let sqrt_arg = self.ins().fsub(arg_squared, F_ONE);
                self.ins().sqrt(sqrt_arg)
            }

            // arctanh(x) -> 1/(1-x^2)
            Opcode::Atanh => {
                // 1 - x^2
                let arg_squared = self.ins().fmul(arg0, arg0);
                self.ins().fsub(F_ONE, arg_squared)
            }

            // // x << y = x*pow(2,y)-> ln(2) * x * y'* pow(2,y)  + x' * pow(2,y)
            // // = ln(2) * y' * x<<y + x' * 1<<y
            // Opcode::Ishl => {
            //     let res = self.ins().ifcast(res);
            //     let lhs_cache = self.ins().fmul(F_LN2, res);

            //     let rhs_cache = self.ins().ishl(ONE, arg1);
            //     let rhs_cache = self.ins().ifcast(rhs_cache);
            //     cache[1] = rhs_cache.into();

            //     lhs_cache
            // }
            // // x >> y = x*pow(2,-y)-> -ln(2) * x * y'* pow(2,-y)  + x' * pow(2,-y)
            // // = -ln(2) * y' * x>>y + x' * 1>>y
            // Opcode::Ishr => {
            //     let res = self.ins().ifcast(res);
            //     let lhs_cache = self.ins().fmul(F_LN2_N, res);

            //     let rhs_cache = self.ins().ishr(ONE, arg1);
            //     let rhs_cache = self.ins().ifcast(rhs_cache);
            //     cache[1] = rhs_cache.into();

            //     lhs_cache
            // }

            // pow(x,y) -> pow(x,y)*(x'*y/x + ln(x) * y')
            // cache[0] = y/x, cache[1] = ln(x), cache[2] = pow(x,y)
            Opcode::Pow => {
                let arg1_div_arg0 = if arg1 == arg0 { F_ONE } else { self.ins().fdiv(arg1, arg0) };
                let ln_x = self.ins().ln(arg0);
                cache[2] = res.into();
                cache[1] = ln_x.into();
                arg1_div_arg0
            }
            // no cache required
            _ => return cache,
        };

        cache[0] = val.into();

        cache
    }

    /// Inserts the instructions that calculate the first order derivative of the result of
    /// `inst` by `unknown` and stores the resulting value with `insert_derivative`.
    ///
    /// `cache` must be the value returned by `inst_cache` for `inst`. Nothing is stored if the
    /// derivative is known to be zero (a missing derivative is treated as zero).
    ///
    /// # Panics
    ///
    /// Panics for branches, jumps and phis; these have to be handled by the caller.
    fn inst_derivative(&mut self, inst: Inst, unknown: Unknown, cache: CacheData) {
        let res = self.func.dfg.first_result(inst);
        let op = self.func.dfg.insts[inst].opcode();

        let args = self.func.dfg.instr_args(inst);
        let arg0 = args.get(0).copied().unwrap_or_else(Value::reserved_value);
        let arg1 = args.get(1).copied().unwrap_or_else(Value::reserved_value);
        // first order derivative of the i-th argument by `unknown`
        let arg_derivative = |sel: &mut DerivativeBuilder, i| {
            sel.derivative_of_1(sel.func.dfg.instr_args(inst)[i], unknown)
        };

        // (f*g)' -> f'*g + g'*f
        let gen_mul_derivative =
            |sel: &mut DerivativeBuilder, mut lhs: Value, mut rhs: Value, cast: bool| {
                let drhs = arg_derivative(sel, 1);
                let dlhs = arg_derivative(sel, 0);
                let res = if cast {
                    // TODO cache
                    // TODO lazy
                    lhs = sel.ins().ifcast(lhs);
                    rhs = sel.ins().ifcast(rhs);
                    sel.ins().ifcast(res)
                } else {
                    res
                };

                // f'*g
                let sum1 = if dlhs == F_ZERO {
                    F_ZERO
                } else if dlhs == F_ONE {
                    rhs
                } else if dlhs == lhs {
                    // f' == f: f'*g is the result itself
                    res
                } else {
                    sel.simplified_mul(lhs, dlhs, rhs, res)
                };
                // g'*f
                let sum2 = if drhs == F_ZERO {
                    return sum1;
                } else if drhs == F_ONE {
                    lhs
                } else if drhs == rhs {
                    // g' == g: g'*f is the result itself
                    res
                } else {
                    sel.simplified_mul(rhs, drhs, lhs, res)
                };

                if dlhs == F_ZERO {
                    return sum2;
                }
                sel.ins().fadd(sum1, sum2)
            };

        // (f/g)' -> (f'*g - g' *f) / g^2 = f'/g - g'*f/g^2
        let gen_div_derivative =
            |sel: &mut DerivativeBuilder, mut lhs: Value, mut rhs: Value, cast: bool| {
                let dlhs = arg_derivative(sel, 0);
                let drhs = arg_derivative(sel, 1);
                let res = if cast {
                    lhs = sel.ins().ifcast(lhs);
                    rhs = sel.ins().ifcast(rhs);
                    sel.ins().ifcast(res)
                } else {
                    res
                };

                // f'/g
                let sum1 = if dlhs == F_ZERO {
                    F_ZERO
                } else if dlhs == lhs {
                    // f' == f: f'/g is the result itself
                    res
                } else {
                    sel.ins().fdiv(dlhs, rhs)
                };

                // f*g'/g^2
                let top = if drhs == F_ZERO {
                    return sum1;
                } else if drhs == F_ONE {
                    lhs
                } else if drhs == rhs {
                    // g' == g: f*g'/g^2 collapses to f/g which is the result itself.
                    // It must not be divided by g^2 again.
                    if sum1 == res {
                        return F_ZERO;
                    }
                    if sum1 == F_ZERO {
                        return sel.ins().fneg(res);
                    }
                    return sel.ins().fsub(sum1, res);
                } else {
                    sel.ins().fmul(drhs, lhs)
                };
                // g^2 (see `inst_cache`)
                let bot = cache[0].unwrap_unchecked();
                let sum2 = sel.ins().fdiv(top, bot);

                sel.ins().fsub(sum1, sum2)
            };

        let val = match op {
            Opcode::Call
            | Opcode::Ineg
            | Opcode::Iadd
            | Opcode::Isub
            | Opcode::Imul
            | Opcode::Idiv
            | Opcode::Ishl
            | Opcode::Ishr
            | Opcode::IFcast
            | Opcode::BIcast
            | Opcode::IBcast
            | Opcode::FBcast
            | Opcode::BFcast
            | Opcode::FIcast
            | Opcode::Irem
            | Opcode::Inot
            | Opcode::Ixor
            | Opcode::Iand
            | Opcode::Ior
            | Opcode::Clog2
            | Opcode::Frem
            | Opcode::Floor
            | Opcode::Ceil
            | Opcode::Bnot
            | Opcode::Ilt
            | Opcode::Igt
            | Opcode::Flt
            | Opcode::Fgt
            | Opcode::Ile
            | Opcode::Ige
            | Opcode::Fle
            | Opcode::Fge
            | Opcode::Ieq
            | Opcode::Feq
            | Opcode::Seq
            | Opcode::Beq
            | Opcode::Ine
            | Opcode::Fne
            | Opcode::Sne
            | Opcode::Bne => {
                // zero no need to store the derivative
                return;
            }

            Opcode::Fneg => {
                let arg = arg_derivative(self, 0);
                self.ins().fneg(arg)
            }

            Opcode::OptBarrier => arg_derivative(self, 0),

            Opcode::Fadd => {
                let dlhs = arg_derivative(self, 0);
                let drhs = arg_derivative(self, 1);
                self.ins().fadd(dlhs, drhs)
            }

            Opcode::Fsub => {
                let dlhs = arg_derivative(self, 0);
                let drhs = arg_derivative(self, 1);
                self.ins().fsub(dlhs, drhs)
            }

            Opcode::Fmul => gen_mul_derivative(self, arg0, arg1, false),
            Opcode::Fdiv => gen_div_derivative(self, arg0, arg1, false),

            // f(x)' = x' * cache[0]
            Opcode::Exp
            | Opcode::Log
            | Opcode::Sin
            | Opcode::Cos
            | Opcode::Sinh
            | Opcode::Cosh
            | Opcode::Tan
            | Opcode::Tanh => {
                let darg = arg_derivative(self, 0);
                match darg {
                    F_ZERO => return,
                    F_ONE => cache[0].unwrap_unchecked(),
                    _ => self.ins().fmul(darg, cache[0].unwrap_unchecked()),
                }
            }

            // f(x)' = x' / cache[0]
            Opcode::Ln
            | Opcode::Sqrt
            | Opcode::Asin
            | Opcode::Acos
            | Opcode::Atan
            | Opcode::Asinh
            | Opcode::Acosh
            | Opcode::Atanh => {
                let darg = arg_derivative(self, 0);
                if darg == F_ZERO {
                    return;
                }
                self.ins().fdiv(darg, cache[0].unwrap_unchecked())
            }

            // f(x, y)' = (x' * cache[0] + y' * cache[1]) * cache[2]
            Opcode::Pow | Opcode::Atan2 => {
                let dlhs = arg_derivative(self, 0);
                let drhs = arg_derivative(self, 1);

                let sum1 = if dlhs == F_ZERO {
                    F_ZERO
                } else if dlhs == F_ONE {
                    cache[0].unwrap_unchecked()
                } else if cache[0].unwrap_unchecked() == F_ONE {
                    dlhs
                } else {
                    self.ins().fmul(dlhs, cache[0].unwrap_unchecked())
                };

                let inner = if drhs == F_ZERO {
                    sum1
                } else {
                    let sum2 = if drhs == F_ONE {
                        cache[1].unwrap_unchecked()
                    } else {
                        self.ins().fmul(drhs, cache[1].unwrap_unchecked())
                    };

                    if sum1 == F_ZERO {
                        sum2
                    } else {
                        self.ins().fadd(sum1, sum2)
                    }
                };
                if inner == F_ZERO {
                    F_ZERO
                } else {
                    self.ins().fmul(inner, cache[2].unwrap_unchecked())
                }
            }

            // hypot(x,y) = sqrt(x^2 + y^2) -> (x*x' + y*y')/hypot(x,y)
            Opcode::Hypot => {
                let dlhs = arg_derivative(self, 0);
                let drhs = arg_derivative(self, 1);

                // x*x'
                let lhs_term = match dlhs {
                    F_ZERO => F_ZERO,
                    F_ONE => arg0,
                    _ => self.ins().fmul(arg0, dlhs),
                };
                // y*y'
                let rhs_term = match drhs {
                    F_ZERO => F_ZERO,
                    F_ONE => arg1,
                    _ => self.ins().fmul(arg1, drhs),
                };

                let top = if lhs_term == F_ZERO {
                    rhs_term
                } else if rhs_term == F_ZERO {
                    lhs_term
                } else {
                    self.ins().fadd(lhs_term, rhs_term)
                };
                if top == F_ZERO {
                    return;
                }

                // divide by the result itself, no cached value is required for this rule
                self.ins().fdiv(top, res)
            }
            Opcode::Br | Opcode::Jmp | Opcode::Phi => unreachable!(),
        };

        self.insert_derivative(res, unknown, val)
    }

    /// Calculates `dlhs * rhs` for the product rule of `res = lhs * rhs`.
    ///
    /// If `dlhs` is itself a product `lhs * k`, the result is built as `res * k` instead. This
    /// makes sure that x = A * exp(C) is derived as A' * exp(C) + x*C' instead of
    /// A'*B + A*(exp(C)*C'), which is hard to optimize correctly.
    fn simplified_mul(&mut self, lhs: Value, dlhs: Value, rhs: Value, res: Value) -> Value {
        // the remaining factor `k` if `dlhs` is a multiplication involving `lhs`
        let remaining_factor = match self.func.dfg.value_def(dlhs).inst() {
            Some(inst) => match self.func.dfg.insts[inst] {
                InstructionData::Binary { opcode: Opcode::Fmul, args: [first, second] }
                    if first == lhs =>
                {
                    Some(second)
                }
                InstructionData::Binary { opcode: Opcode::Fmul, args: [first, second] }
                    if second == lhs =>
                {
                    Some(first)
                }
                _ => None,
            },
            None => None,
        };

        match remaining_factor {
            Some(factor) => self.ins().fmul(res, factor),
            None => self.ins().fmul(dlhs, rhs),
        }
    }
}

/// Describes a cache that was built by `DerivativeBuilder::ensure_cache`.
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub struct CacheInfo {
    /// The instructions that were inserted while building the cache.
    instrs: Range<Inst>,
    /// The entries in `BuilderCache::cache_data` that belong to this cache.
    data: IdxRange<CacheData>,
}

/// The instructions that make up one (already generated) derivative of an instruction.
#[derive(Debug, PartialEq, Clone, Eq, Hash)]
pub struct ResolvedDerivative {
    /// The instructions that were inserted to calculate the derivative.
    instrs: Range<Inst>,
    /// The instructions of the cache that was used to calculate the derivative.
    cache_instrs: Range<Inst>,
}

impl ResolvedDerivative {
    fn root_instr(pos: Inst) -> ResolvedDerivative {
        // The original instruction (so something the user typed) never has a cache and is always the first
        // instruction.
        ResolvedDerivative {
            instrs: pos..Inst::from(u32::from(pos) + 1),
            cache_instrs: pos..Inst::from(u32::from(pos)),
        }
    }

    fn instructions(&self) -> impl Iterator<Item = Inst> {
        let instrs: Range<u32> = self.instrs.start.into()..self.instrs.end.into();
        let cache_instrs: Range<u32> = self.cache_instrs.start.into()..self.cache_instrs.end.into();
        cache_instrs.chain(instrs).map(Inst::from)
    }
}

/// Scratch data used while generating the derivatives of a single instruction.
///
/// The cache is cleared at the end of `DerivativeBuilder::build_normal_inst_derivatives`, so
/// one instance can be reused for all instructions.
#[derive(Default)]
pub struct BuilderCache {
    /// The instructions generated per derivative; `None` is the original instruction.
    resolved_derivatives: AHashMap<Option<Derivative>, ResolvedDerivative>,
    /// Backing storage for the values produced by `DerivativeBuilder::inst_cache`.
    cache_data: Arena<CacheData>,
    /// The caches that were already built, keyed by the derivative they were built for.
    derivative_cache: AHashMap<Option<Derivative>, CacheInfo>,
}

impl BuilderCache {
    /// Removes all entries from all three containers so the cache can be reused for the next
    /// instruction.
    fn clear(&mut self) {
        self.resolved_derivatives.clear();
        self.derivative_cache.clear();
        self.cache_data.clear();
    }
}
